Pixel explanations for a deep neural network diabetic retinopathy classification model

Let's import some libraries required for the notebook to work:

In [1]:

Let's define some constants related to the model:

In [2]:
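The cell itself is not shown in this export. Judging from the shapes printed later in the notebook (640x640 inputs, 5 output classes, a first `Conv2d(3, 16, ...)` block), the constants might look roughly like this; the names and the exact set of constants are assumptions:

```python
# Hypothetical constants inferred from the shapes printed later in the
# notebook; the real cell is not shown in this export.
IMG_SIZE = 640       # input resolution, matches torch.Size([5, 1, 640, 640])
IN_CHANNELS = 3      # RGB input, matches Conv2d(3, 16, ...) in the first block
N_CLASSES = 5        # diabetic retinopathy grades 0-4
CHECKPOINT = 'models/ret6_bn/model_best-QWKval0814.pth.tar'

print(IMG_SIZE, IN_CHANNELS, N_CLASSES)
```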

And now, let's load the test dataset:

In [3]:

Let's load the previously trained model:

In [4]:
=> loading checkpoint 'models/ret6_bn/model_best-QWKval0814.pth.tar'
=> loaded checkpoint 'models/ret6_bn/model_best-QWKval0814.pth.tar' (epoch 638)

We instantiate the previously defined model with the parameters of the pretrained model loaded above:

In [5]:
Out[5]:
model_explainable(
  (rf3): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (rf5): Sequential(
    (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (rf9): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (rf13): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (rf21): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (rf29): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (rf45): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (rf61): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (rf93): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (rf125): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (rf189): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (rf253): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (rf381): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (rf509): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (rf637): Sequential(
    (0): Conv2d(64, 64, kernel_size=(2, 2), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (lc): Sequential(
    (0): Linear(in_features=64, out_features=5, bias=True)
  )
)

Now, we load an image from the dataset. This is the image that will be analyzed:

In [49]:

Let's visualize the image:

In [50]:

Let's use the model to compute the predicted class. We also record the intermediate activations of every layer:

In [51]:
Image nr:9854
File: ../input/640_test/4/23643_right.jpeg
Image name: 23643_right
Tag
Output before softmax: 
tensor([[-184.3373, -126.4893, -104.0732,  -44.3767,   97.2942]],
       grad_fn=<AddmmBackward>)
Predicted class: tensor(4)
Target: 
tensor([4])
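The cell that runs the PyTorch model is not shown in this export. The pattern it describes, running the input through the network while keeping every intermediate activation, can be illustrated with a toy NumPy stand-in (the layers and shapes here are purely illustrative, not the notebook's model):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for the network: a list of (weight, bias) layers with ReLU.
# The real cell does the analogous thing with the PyTorch model: push the
# input through the layers one at a time and keep each output.
layers = [(rng.standard_normal((4, 4)), rng.standard_normal(4)) for _ in range(3)]

def forward_with_activations(x, layers):
    """Return the final output plus every intermediate activation."""
    activations = [x]
    for w, b in layers:
        x = np.maximum(0.0, x @ w + b)   # linear layer + ReLU
        activations.append(x)
    return x, activations

out, acts = forward_with_activations(rng.standard_normal(4), layers)
print(len(acts))   # input plus one activation per layer
```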
In [52]:
4

Last layer feature activations

And now we plot the feature activations of the last layer before the output layer:

In [53]:

We create a variable to accumulate the constant (bias) contributions of every layer, mapped into the input space. At the end, this accumulated value will be added to the scores computed for the input space.

In [54]:
In [55]:
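The accumulator cell is not shown. A minimal sketch of the idea, with the shape taken from the per-class input-space score maps printed later (`torch.Size([5, 1, 640, 640])`) and an illustrative constant value:

```python
import numpy as np

# One 640x640 accumulator map per class, matching the per-class input-space
# score maps printed later in the notebook (torch.Size([5, 1, 640, 640])).
N_CLASSES, IMG_SIZE = 5, 640
bias_scores = np.zeros((N_CLASSES, 1, IMG_SIZE, IMG_SIZE), dtype=np.float32)

# At each propagation step, the layer's constant contribution (already mapped
# into input space) would be accumulated in place:
layer_constant = np.full_like(bias_scores, 1e-4)  # illustrative value
bias_scores += layer_constant

print(bias_scores.shape)
```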

Score propagation through the average pooling layer:

In [56]:
Max: 5.114893913269043, Min: -3.349405288696289, Avg: -0.07076960057020187, Std:0.3903641402721405
torch.Size([5, 64, 4, 4])
Out[56]:
(tensor(5.1149), tensor(-3.3494), tensor(-0.0708), tensor(0.3904))
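The exact propagation rule is not visible in this export. A common choice for average pooling, and one plausible reading of this step, is to redistribute each pooled feature's score over the pooling window in proportion to the activations (the z-rule used in layer-wise relevance propagation). A minimal NumPy sketch:

```python
import numpy as np

def lrp_avg_pool(scores, activations, eps=1e-9):
    """Redistribute per-feature scores through global average pooling.

    scores:      (C,)      relevance of each pooled feature
    activations: (C, H, W) pre-pooling activations
    Each spatial cell receives a share proportional to its activation,
    so the shares of one feature map sum back to that map's score.
    """
    denom = activations.sum(axis=(1, 2), keepdims=True) + eps
    return activations / denom * scores[:, None, None]

rng = np.random.default_rng(1)
acts = np.abs(rng.standard_normal((64, 4, 4)))   # ReLU outputs are nonnegative
scores = rng.standard_normal(64)
r = lrp_avg_pool(scores, acts)
print(r.shape, np.allclose(r.sum(axis=(1, 2)), scores))
```

Note the conservation property: summing the redistributed scores of each feature map recovers the original pooled score.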

Scores of the individual feature maps of the 15th layer

In [57]:

Aggregated feature map scores of the 15th layer

In [58]:
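The aggregation cell is not shown; summing over the 64 feature maps is one plausible reading, since it collapses the per-map scores of shape `(5, 64, 4, 4)` to a single map per class, matching the `torch.Size([5, 1, 4, 4])` printed later:

```python
import numpy as np

# Summing over the feature-map axis collapses (n_classes, 64, H, W)
# down to one aggregated score map per class: (n_classes, 1, H, W).
rng = np.random.default_rng(2)
per_map = rng.standard_normal((5, 64, 4, 4))     # illustrative values
aggregated = per_map.sum(axis=1, keepdims=True)
print(aggregated.shape)
```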

Aggregated scores of 15th layer mapped into pixel space using a 2d-normal 2std=rf prior

In [59]:

Score propagation through classification layer (2x2 convolution)

In [60]:
Max: 4.764776229858398, Min: -7.107814788818359, Avg: -0.06915958970785141, Std:0.6068668961524963
torch.Size([5, 64, 5, 5])
Out[60]:
(tensor(4.7648), tensor(-7.1078), tensor(-0.0692), tensor(0.6069))
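Again, the propagation cell is not shown. Assuming a z-rule-style redistribution through the 2x2 convolution (each output unit's score split over its input patch in proportion to the contributions w * a), a naive NumPy sketch could look like this; the shapes are toy values, not the model's:

```python
import numpy as np

def lrp_conv2x2(scores, activations, weight, eps=1e-9):
    """z-rule score propagation through a 2x2 valid convolution (stride 1).

    scores:      (Cout, Ho, Wo)    scores at the conv output
    activations: (Cin, Ho+1, Wo+1) conv input
    weight:      (Cout, Cin, 2, 2)
    Each output unit's score is split over its 2x2 input patch in
    proportion to the contributions z = w * a, then summed per input cell.
    """
    r_in = np.zeros_like(activations)
    cout, ho, wo = scores.shape
    for co in range(cout):
        for i in range(ho):
            for j in range(wo):
                z = weight[co] * activations[:, i:i + 2, j:j + 2]
                r_in[:, i:i + 2, j:j + 2] += z / (z.sum() + eps) * scores[co, i, j]
    return r_in

rng = np.random.default_rng(3)
a = np.abs(rng.standard_normal((3, 5, 5)))   # ReLU outputs are nonnegative
w = rng.standard_normal((2, 3, 2, 2))
s = rng.standard_normal((2, 4, 4))
r = lrp_conv2x2(s, a, w)
print(r.shape)   # scores now live on the 5x5 conv input
```

The explicit loops are for clarity only; a practical implementation would vectorize this or reuse the framework's convolution primitives.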
In [61]:
torch.Size([5, 1, 4, 4])
In [62]:

In the previous map, every pixel corresponds to a receptive field, and the activations shown above overlap. We know from the literature that receptive fields have an approximately Gaussian detection profile. Modelling each receptive field as a 2D Gaussian whose 2 standard deviations span the field, we can map the activations back to input space. The result is shown below:

In [63]:
Max: 0.0009453239617869258, Min: -0.0005918593378737569, Avg: 9.323054837295786e-05, Std:0.00036903333966620266
torch.Size([5, 1, 640, 640])
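The mapping cell is not shown. One way to realize the Gaussian model described above is to "splat" each low-resolution score onto the input canvas with a normalized Gaussian kernel whose ±2 sigma spans the receptive field (so sigma = rf / 4, an assumption consistent with the "2std=rf" heading). The receptive-field size, stride, and canvas size below are illustrative toy values, not the model's:

```python
import numpy as np

def gaussian_kernel(rf, sigma):
    """rf x rf samples of a 2D Gaussian, normalized to sum to 1."""
    ax = np.arange(rf) - (rf - 1) / 2.0
    g = np.exp(-(ax ** 2) / (2 * sigma ** 2))
    k = np.outer(g, g)
    return k / k.sum()

def splat_to_input(scores, rf, stride, size):
    """Spread a low-resolution score map over the input-space canvas.

    Each cell of `scores` corresponds to an rf x rf receptive field; its
    score is distributed over that field with a Gaussian whose +/-2 sigma
    spans the field, so overlapping fields simply add up.
    """
    k = gaussian_kernel(rf, sigma=rf / 4.0)   # 2 std = rf
    canvas = np.zeros((size, size))
    h, w = scores.shape
    for i in range(h):
        for j in range(w):
            top, left = i * stride, j * stride
            canvas[top:top + rf, left:left + rf] += scores[i, j] * k
    return canvas

rng = np.random.default_rng(4)
s = rng.standard_normal((4, 4))              # toy 4x4 score map
m = splat_to_input(s, rf=16, stride=16, size=64)
print(m.shape)
```

Because each kernel sums to 1 and every field fits inside the canvas here, the total score is conserved by the mapping.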

Scores of the feature maps of the 14th layer

In [64]:

Aggregated feature map scores of the 14th layer

In [65]:

Aggregated feature map scores of the 14th layer mapped into input space

In [66]:
In [67]:
Max: 7.961204528808594, Min: -7.503911018371582, Avg: -0.0062023852951824665, Std:0.3135142922401428
torch.Size([5, 64, 10, 10])
Max: 9.753778457641602, Min: -10.782066345214844, Avg: -0.011087514460086823, Std:0.31653109192848206
torch.Size([5, 64, 10, 10])
Out[67]:
(tensor(9.7538), tensor(-10.7821), tensor(-0.0111), tensor(0.3165))
In [68]:
torch.Size([5, 1, 10, 10])
In [69]:
Max: 0.0004309612268116325, Min: -0.0012467608321458101, Avg: -0.00017324268992524594, Std:0.00031734094955027103
torch.Size([5, 1, 640, 640])
In [70]:
In [71]:
In [72]:
In [73]:
Max: 5.352302551269531, Min: -5.8788557052612305, Avg: -0.018513550981879234, Std:0.3383949398994446
torch.Size([5, 64, 10, 10])
Out[73]:
(tensor(5.3523), tensor(-5.8789), tensor(-0.0185), tensor(0.3384))
In [74]:
torch.Size([5, 1, 10, 10])
In [75]:
Max: 0.0011963018914684653, Min: -0.000941299251280725, Avg: 0.00019236220396123827, Std:0.00031793309608474374
torch.Size([5, 1, 640, 640])
In [76]:
In [77]:
In [78]:
In [79]:
Max: 5.164151191711426, Min: -5.989358425140381, Avg: -0.003520882222801447, Std:0.13551212847232819
torch.Size([5, 64, 20, 20])
Out[79]:
(tensor(5.1642), tensor(-5.9894), tensor(-0.0035), tensor(0.1355))
In [80]:
In [ ]:
In [81]:
Max: 0.0015183317009359598, Min: -0.0020616967231035233, Avg: -6.921908789081499e-05, Std:0.0004113923932891339
torch.Size([5, 1, 640, 640])
In [82]:
In [83]:
In [84]:
In [85]:
Max: 8.443958282470703, Min: -11.712089538574219, Avg: -0.0026857329066842794, Std:0.21393094956874847
torch.Size([5, 64, 20, 20])
Out[85]:
(tensor(8.4440), tensor(-11.7121), tensor(-0.0027), tensor(0.2139))
In [86]:
In [87]:
Max: 0.003119109896942973, Min: -0.00269049359485507, Avg: -5.219684317125939e-05, Std:0.000684752652887255
torch.Size([5, 1, 640, 640])
In [88]:
In [89]:
In [90]:
In [91]:
Max: 6.2140045166015625, Min: -12.704032897949219, Avg: -0.0007335206028074026, Std:0.08428094536066055
torch.Size([5, 64, 40, 40])
Out[91]:
(tensor(6.2140), tensor(-12.7040), tensor(-0.0007), tensor(0.0843))
In [92]:
In [93]:
Max: 0.012140105478465557, Min: -0.013842101208865643, Avg: 1.5521731256740168e-05, Std:0.0020233530085533857
torch.Size([5, 1, 640, 640])
In [94]:
In [95]:
In [96]:
In [97]:
Max: 17.412944793701172, Min: -17.72927474975586, Avg: -0.0006762614357285202, Std:0.11676545441150665
torch.Size([5, 64, 40, 40])
Out[97]:
(tensor(17.4129), tensor(-17.7293), tensor(-0.0007), tensor(0.1168))
In [98]:
In [99]:
Max: 0.01862882263958454, Min: -0.013836498372256756, Avg: -1.431467717338819e-05, Std:0.0023974161595106125
torch.Size([5, 1, 640, 640])
In [100]:
In [101]:
In [102]:
In [103]:
Max: 4.04331111907959, Min: -11.292847633361816, Avg: -7.783865294186398e-05, Std:0.05109566077589989
torch.Size([5, 64, 80, 80])
Out[103]:
(tensor(4.0433), tensor(-11.2928), tensor(-7.7839e-05), tensor(0.0511))
In [104]:
torch.Size([5, 1, 80, 80])
In [105]:
Max: 0.014074004255235195, Min: -0.017603276297450066, Avg: -9.122679330175743e-05, Std:0.0016115723410621285
torch.Size([5, 1, 640, 640])
In [106]:
In [107]:
In [108]:
In [109]:
Max: 8.34609603881836, Min: -17.917417526245117, Avg: -0.00011916153016500175, Std:0.06798889487981796
torch.Size([5, 64, 80, 80])
Out[109]:
(tensor(8.3461), tensor(-17.9174), tensor(-0.0001), tensor(0.0680))
In [110]:
torch.Size([5, 1, 80, 80])
In [111]:
Max: 0.033730264753103256, Min: -0.043616391718387604, Avg: 4.1322804463561624e-05, Std:0.0032907805871218443
torch.Size([5, 1, 640, 640])
In [112]:
In [113]:
In [114]:
In [115]:
In [116]:
Max: 5.073758602142334, Min: -11.340737342834473, Avg: -1.3102913726470433e-05, Std:0.027750037610530853
torch.Size([5, 64, 160, 160])
Out[116]:
(tensor(5.0738), tensor(-11.3407), tensor(-1.3103e-05), tensor(0.0278))
In [117]:
torch.Size([5, 1, 160, 160])
In [118]:
Max: 0.03664026036858559, Min: -0.05075312778353691, Avg: -6.675082113360986e-05, Std:0.0027678257320076227
torch.Size([5, 1, 640, 640])
In [119]:
In [120]:
In [121]:
In [122]:
Max: 5.852373123168945, Min: -24.580745697021484, Avg: -4.305892434786074e-05, Std:0.07713824510574341
torch.Size([5, 32, 160, 160])
Out[122]:
(tensor(5.8524), tensor(-24.5807), tensor(-4.3059e-05), tensor(0.0771))
In [123]:
torch.Size([5, 1, 160, 160])
In [124]:
Max: 0.05444978550076485, Min: -0.22009994089603424, Avg: 3.370633930899203e-05, Std:0.004536021500825882
torch.Size([5, 1, 640, 640])
In [125]:
In [126]:
In [127]:
In [128]:
Max: 7.03353214263916, Min: -18.952346801757812, Avg: -8.108151632768568e-06, Std:0.0269682127982378
torch.Size([5, 32, 320, 320])
Out[128]:
(tensor(7.0335), tensor(-18.9523), tensor(-8.1082e-06), tensor(0.0270))
In [129]:
torch.Size([5, 1, 320, 320])
In [130]:
Max: 0.4995645582675934, Min: -1.4880867004394531, Avg: -2.125450555467978e-05, Std:0.019385717809200287
torch.Size([5, 1, 640, 640])
In [131]:
In [132]:
In [133]:
In [134]:
Max: 43.391136169433594, Min: -46.139102935791016, Avg: -3.980544352089055e-05, Std:0.12974578142166138
torch.Size([5, 16, 320, 320])
Out[134]:
(tensor(43.3911), tensor(-46.1391), tensor(-3.9805e-05), tensor(0.1297))
In [135]:
torch.Size([5, 1, 320, 320])
In [136]:
Max: 1.015857458114624, Min: -2.311088800430298, Avg: 9.436165419174358e-05, Std:0.024955159053206444
torch.Size([5, 1, 640, 640])
In [137]:
In [138]:
In [139]:
In [140]:
In [141]:
Max: 28.672992706298828, Min: -32.63753128051758, Avg: -1.0890984412981197e-05, Std:0.06722572445869446
torch.Size([5, 16, 640, 640])
Out[141]:
(tensor(28.6730), tensor(-32.6375), tensor(-1.0891e-05), tensor(0.0672))
In [142]:
torch.Size([5, 1, 640, 640])
In [143]:
Max: 13.400936126708984, Min: -21.839786529541016, Avg: 1.5029602764116134e-05, Std:0.1516730785369873
torch.Size([5, 1, 640, 640])
In [144]:
In [145]:
In [146]:
In [147]:
In [148]:
Max: 117.44732666015625, Min: -37.60302734375, Avg: -3.488934089546092e-05, Std:0.39204826951026917
torch.Size([5, 3, 640, 640])
In [149]:
torch.Size([5, 1, 640, 640])
In [150]:
Max: 25.520057678222656, Min: -28.88320541381836, Avg: -6.958143058000132e-05, Std:0.1578371524810791
torch.Size([5, 1, 640, 640])

Score Input

In [151]:
In [152]:
mean=0.00023709492234047502 std=1.056968092918396
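The printed statistics (mean close to 0, standard deviation close to 1) suggest the scored input has been standardized; the cell is not shown, but a minimal sketch of such a step, with stand-in values, would be:

```python
import numpy as np

# Standardize a per-pixel score map to zero mean and unit standard deviation.
# The values here are stand-ins; the real map comes from the model above.
rng = np.random.default_rng(5)
pixel_scores = rng.standard_normal((640, 640)) * 3 + 1
normalized = (pixel_scores - pixel_scores.mean()) / pixel_scores.std()
print(normalized.mean(), normalized.std())
```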
In [154]:
Out[154]:
4
In [155]:
Out[155]:
torch.Size([1, 5, 640, 640])
In [156]:
In [163]:
In [ ]:
In [ ]: